{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../scripts/')\n",
    "from puddle_world import *\n",
    "import itertools \n",
    "import collections \n",
    "from copy import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DynamicProgramming: \n",
    "    def __init__(self, widths, goal, puddles, time_interval, sampling_num, \\\n",
    "                 puddle_coef=100.0, lowerleft=np.array([-4, -4]).T, upperright=np.array([4, 4]).T): \n",
    "        self.pose_min = np.r_[lowerleft, 0]\n",
    "        self.pose_max = np.r_[upperright, math.pi*2]\n",
    "        self.widths = widths\n",
    "        self.goal = goal\n",
    "        \n",
    "        self.index_nums = ((self.pose_max - self.pose_min)/self.widths).astype(int)\n",
    "        nx, ny, nt = self.index_nums\n",
    "        self.indexes = list(itertools.product(range(nx), range(ny), range(nt)))\n",
    "        \n",
    "        self.value_function, self.final_state_flags =  self.init_value_function() \n",
    "        self.policy = self.init_policy()\n",
    "        self.actions = list(set([tuple(self.policy[i]) for i in self.indexes]))\n",
    "        \n",
    "        self.state_transition_probs = self.init_state_transition_probs(time_interval, sampling_num)\n",
    "        self.depths = self.depth_means(puddles, sampling_num)\n",
    "        \n",
    "        self.time_interval = time_interval\n",
    "        self.puddle_coef = puddle_coef\n",
    "        \n",
    "    def value_iteration_sweep(self): #追加\n",
    "        max_delta = 0.0\n",
    "        for index in self.indexes:\n",
    "            if not self.final_state_flags[index]:\n",
    "                max_q = -1e100\n",
    "                max_a = None\n",
    "                qs = [self.action_value(a, index) for a in self.actions] #全行動の行動価値を計算\n",
    "                max_q = max(qs)                               #最大の行動価値\n",
    "                max_a = self.actions[np.argmax(qs)]   #最大の行動価値を与える行動\n",
    "                \n",
    "                delta = abs(self.value_function[index] - max_q)            #変化量\n",
    "                max_delta = delta if delta > max_delta else max_delta #スイープ中で最大の変化量の更新\n",
    "                \n",
    "                self.value_function[index] = max_q      #価値の更新\n",
    "                self.policy[index] = np.array(max_a).T  #方策の更新\n",
    "            \n",
    "        return max_delta        \n",
    "        \n",
    "    def policy_evaluation_sweep(self):   \n",
    "        max_delta = 0.0\n",
    "        for index in self.indexes:\n",
    "            if not self.final_state_flags[index]:\n",
    "                q = self.action_value(tuple(self.policy[index]), index)\n",
    "                \n",
    "                delta = abs(self.value_function[index] - q)\n",
    "                max_delta = delta if delta > max_delta else max_delta\n",
    "                \n",
    "                self.value_function[index] = q\n",
    "            \n",
    "        return max_delta\n",
    "    \n",
    "    def action_value(self, action, index, out_penalty=True): \n",
    "        value = 0.0\n",
    "        for delta, prob in self.state_transition_probs[(action, index[2])]: \n",
    "            after, out_reward = self.out_correction(np.array(index).T + delta)\n",
    "            after = tuple(after)\n",
    "            reward = - self.time_interval * self.depths[(after[0], after[1])] * self.puddle_coef - self.time_interval + out_reward*out_penalty\n",
    "            value += (self.value_function[after] + reward) * prob\n",
    "\n",
    "        return value\n",
    "            \n",
    "    def out_correction(self, index): #変更\n",
    "        out_reward = 0.0\n",
    "        index[2] = (index[2] + self.index_nums[2])%self.index_nums[2] #方角の処理\n",
    "        \n",
    "        for i in range(2):\n",
    "            if index[i] < 0:\n",
    "                index[i] = 0\n",
    "                out_reward = -1e100\n",
    "            elif index[i] >= self.index_nums[i]:\n",
    "                index[i] = self.index_nums[i]-1\n",
    "                out_reward = -1e100\n",
    "                \n",
    "        return index, out_reward\n",
    "        \n",
    "    def depth_means(self, puddles, sampling_num):\n",
    "        ###セルの中の座標を均等にsampling_num**2点サンプリング###\n",
    "        dx = np.linspace(0, self.widths[0], sampling_num) \n",
    "        dy = np.linspace(0, self.widths[1], sampling_num)\n",
    "        samples = list(itertools.product(dx, dy))\n",
    "        \n",
    "        tmp = np.zeros(self.index_nums[0:2]) #深さの合計が計算されて入る\n",
    "        for xy in itertools.product(range(self.index_nums[0]), range(self.index_nums[1])):\n",
    "            for s in samples:\n",
    "                pose = self.pose_min + self.widths*np.array([xy[0], xy[1], 0]).T + np.array([s[0], s[1], 0]).T #セルの中心の座標\n",
    "                for p in puddles:\n",
    "                    tmp[xy] += p.depth*p.inside(pose) #深さに水たまりの中か否か（1 or 0）をかけて足す\n",
    "                        \n",
    "            tmp[xy] /= sampling_num**2 #深さの合計から平均値に変換\n",
    "                       \n",
    "        return tmp\n",
    "    \n",
    "    def init_state_transition_probs(self, time_interval, sampling_num):\n",
    "        ###セルの中の座標を均等にsampling_num**3点サンプリング###\n",
    "        dx = np.linspace(0.001, self.widths[0]*0.999, sampling_num) #隣のセルにはみ出さないように端を避ける\n",
    "        dy = np.linspace(0.001, self.widths[1]*0.999, sampling_num)\n",
    "        dt = np.linspace(0.001, self.widths[2]*0.999, sampling_num)\n",
    "        samples = list(itertools.product(dx, dy, dt))\n",
    "        \n",
    "        ###各行動、各方角でサンプリングした点を移動してインデックスの増分を記録###\n",
    "        tmp = {}\n",
    "        for a in self.actions:\n",
    "            for i_t in range(self.index_nums[2]):\n",
    "                transitions = []\n",
    "                for s in samples:\n",
    "                    before = np.array([s[0], s[1], s[2] + i_t*self.widths[2]]).T + self.pose_min  #遷移前の姿勢\n",
    "                    before_index = np.array([0, 0, i_t]).T                                                      #遷移前のインデックス\n",
    "                \n",
    "                    after = IdealRobot.state_transition(a[0], a[1], time_interval, before)   #遷移後の姿勢\n",
    "                    after_index = np.floor((after - self.pose_min)/self.widths).astype(int)   #遷移後のインデックス\n",
    "                    \n",
    "                    transitions.append(after_index - before_index)                                  #インデックスの差分を追加\n",
    "                    \n",
    "                unique, count = np.unique(transitions, axis=0, return_counts=True)   #集計（どのセルへの遷移が何回か）\n",
    "                probs = [c/sampling_num**3 for c in count]                   #サンプル数で割って確率にする\n",
    "                tmp[a,i_t] = list(zip(unique, probs))\n",
    "                \n",
    "        return tmp\n",
    "        \n",
    "    def init_policy(self):\n",
    "        tmp = np.zeros(np.r_[self.index_nums,2]) #制御出力が2次元なので、配列の次元を4次元に\n",
    "        for index in self.indexes:\n",
    "            center = self.pose_min + self.widths*(np.array(index).T + 0.5)  #セルの中心の座標\n",
    "            tmp[index] = PuddleIgnoreAgent.policy(center, self.goal)\n",
    "            \n",
    "        return tmp\n",
    "        \n",
    "    def init_value_function(self): \n",
    "        v = np.empty(self.index_nums) #全離散状態を要素に持つ配列を作成\n",
    "        f = np.zeros(self.index_nums) \n",
    "        \n",
    "        for index in self.indexes:\n",
    "            f[index] = self.final_state(np.array(index).T)\n",
    "            v[index] = self.goal.value if f[index] else -100.0\n",
    "                \n",
    "        return v, f\n",
    "        \n",
    "    def final_state(self, index):\n",
    "        x_min, y_min, _ = self.pose_min + self.widths*index          #xy平面で左下の座標\n",
    "        x_max, y_max, _ = self.pose_min + self.widths*(index + 1) #右上の座標（斜め上の離散状態の左下の座標）\n",
    "        \n",
    "        corners = [[x_min, y_min, _], [x_min, y_max, _], [x_max, y_min, _], [x_max, y_max, _] ] #4隅の座標\n",
    "        return all([self.goal.inside(np.array(c).T) for c in corners ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
